package com.itmuch.cloud.study.user.service;

import java.io.IOException;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.function.Consumer;

import org.springframework.http.MediaType;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

import lombok.extern.slf4j.Slf4j;

/**
 * Server-Sent Events <BR>
 * https://blog.csdn.net/hhl18730252820/article/details/126244274
 */
@Slf4j
public class SSEServer
{
    private static List<SseEmitter> sseEmitters = new CopyOnWriteArrayList<>();
    
    public static SseEmitter connect()
    {
        SseEmitter sseEmitter = new SseEmitter(0L); // 设置超时时间，0表示不过期，默认是30秒，超过时间未完成会抛出异常
        
        // 注册回调
        sseEmitter.onCompletion(completionCallBack(sseEmitter));
        sseEmitter.onError(errorCallBack(sseEmitter));
        sseEmitter.onTimeout(timeOutCallBack(sseEmitter));
        sseEmitters.add(sseEmitter);
        log.info("###### create new sse connect, count: {}", sseEmitters.size());
        return sseEmitter;
    }
    
    public static void batchSendMessage(String message)
    {
        sseEmitters.forEach(it -> {
            try
            {
                it.send(message, MediaType.APPLICATION_JSON);
            }
            catch (IOException e)
            {
                log.error("send message error: {}", e.getMessage());
                remove(it);
            }
        });
    }
    
    /**
     * 指定name,发送message
     * 
     * @param name
     * @param message 普通字符串或json数据
     */
    public static void batchSendMessage(String name, String message)
    {
        sseEmitters.forEach(it -> {
            try
            {
                it.send(SseEmitter.event().name(name).data(message));
            }
            catch (IOException e)
            {
                log.error("send message error: {}", e.getMessage());
                remove(it);
            }
        });
    }
    
    public static void remove(SseEmitter s)
    {
        if (sseEmitters.contains(s))
        {
            sseEmitters.remove(s);
            log.info("###### remove SseEmitter, count: {}", sseEmitters.size());
        }
    }
    
    private static Runnable completionCallBack(SseEmitter s)
    {
        return () -> {
            log.info("结束连接");
            remove(s);
        };
    }
    
    private static Runnable timeOutCallBack(SseEmitter s)
    {
        return () -> {
            log.info("连接超时");
            remove(s);
        };
    }
    
    private static Consumer<Throwable> errorCallBack(SseEmitter s)
    {
        return throwable -> {
            log.error("连接异常");
            remove(s);
        };
    }
}